{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GDP 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pandas_datareader import wb\n",
    "\n",
    "import torch\n",
    "import torch.nn\n",
    "import torch.optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>country</th>\n",
       "      <th>Brazil</th>\n",
       "      <th>Canada</th>\n",
       "      <th>China</th>\n",
       "      <th>France</th>\n",
       "      <th>Germany</th>\n",
       "      <th>India</th>\n",
       "      <th>Israel</th>\n",
       "      <th>Japan</th>\n",
       "      <th>Saudi Arabia</th>\n",
       "      <th>United Kingdom</th>\n",
       "      <th>United States</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>year</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1970</th>\n",
       "      <td>4706.126393</td>\n",
       "      <td>24629.215564</td>\n",
       "      <td>228.317703</td>\n",
       "      <td>20090.770923</td>\n",
       "      <td>19624.749759</td>\n",
       "      <td>365.057383</td>\n",
       "      <td>14476.725344</td>\n",
       "      <td>18435.455076</td>\n",
       "      <td>22133.904924</td>\n",
       "      <td>17934.191423</td>\n",
       "      <td>23309.620946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1971</th>\n",
       "      <td>5108.945626</td>\n",
       "      <td>25262.441350</td>\n",
       "      <td>237.813838</td>\n",
       "      <td>20985.170602</td>\n",
       "      <td>20202.433743</td>\n",
       "      <td>362.767725</td>\n",
       "      <td>14750.563777</td>\n",
       "      <td>19054.841724</td>\n",
       "      <td>25517.184135</td>\n",
       "      <td>18481.731208</td>\n",
       "      <td>23775.276923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1972</th>\n",
       "      <td>5586.683824</td>\n",
       "      <td>26216.591277</td>\n",
       "      <td>240.881889</td>\n",
       "      <td>21739.884077</td>\n",
       "      <td>20970.626433</td>\n",
       "      <td>352.550056</td>\n",
       "      <td>16127.116380</td>\n",
       "      <td>20370.673766</td>\n",
       "      <td>29931.470715</td>\n",
       "      <td>19211.556510</td>\n",
       "      <td>24760.145377</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1973</th>\n",
       "      <td>6216.130816</td>\n",
       "      <td>27571.292534</td>\n",
       "      <td>253.714373</td>\n",
       "      <td>22903.302398</td>\n",
       "      <td>21903.403507</td>\n",
       "      <td>355.788210</td>\n",
       "      <td>16352.902703</td>\n",
       "      <td>21825.543720</td>\n",
       "      <td>35393.583566</td>\n",
       "      <td>20422.489067</td>\n",
       "      <td>25908.912802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1974</th>\n",
       "      <td>6617.883590</td>\n",
       "      <td>28080.940720</td>\n",
       "      <td>254.267485</td>\n",
       "      <td>23690.600278</td>\n",
       "      <td>22089.748966</td>\n",
       "      <td>351.708069</td>\n",
       "      <td>16901.495224</td>\n",
       "      <td>21150.496237</td>\n",
       "      <td>39125.445624</td>\n",
       "      <td>19906.842338</td>\n",
       "      <td>25540.501003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1975</th>\n",
       "      <td>6798.096714</td>\n",
       "      <td>28057.047887</td>\n",
       "      <td>271.599476</td>\n",
       "      <td>23298.340492</td>\n",
       "      <td>21980.087890</td>\n",
       "      <td>375.083373</td>\n",
       "      <td>17055.638892</td>\n",
       "      <td>21458.049820</td>\n",
       "      <td>33860.391587</td>\n",
       "      <td>19613.863820</td>\n",
       "      <td>25239.919906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1976</th>\n",
       "      <td>7287.519095</td>\n",
       "      <td>29128.014087</td>\n",
       "      <td>263.230622</td>\n",
       "      <td>24174.746998</td>\n",
       "      <td>23167.059261</td>\n",
       "      <td>372.642613</td>\n",
       "      <td>16732.882702</td>\n",
       "      <td>22146.595979</td>\n",
       "      <td>37905.118963</td>\n",
       "      <td>20189.605653</td>\n",
       "      <td>26347.809282</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1977</th>\n",
       "      <td>7443.898645</td>\n",
       "      <td>29783.267924</td>\n",
       "      <td>279.324547</td>\n",
       "      <td>24907.284189</td>\n",
       "      <td>23996.772912</td>\n",
       "      <td>390.636680</td>\n",
       "      <td>16531.778366</td>\n",
       "      <td>22897.185145</td>\n",
       "      <td>38557.214887</td>\n",
       "      <td>20689.657017</td>\n",
       "      <td>27286.251514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1978</th>\n",
       "      <td>7504.296020</td>\n",
       "      <td>30651.632553</td>\n",
       "      <td>307.766195</td>\n",
       "      <td>25811.802598</td>\n",
       "      <td>24740.236527</td>\n",
       "      <td>403.633544</td>\n",
       "      <td>17085.997506</td>\n",
       "      <td>23887.179963</td>\n",
       "      <td>34657.430841</td>\n",
       "      <td>21557.849741</td>\n",
       "      <td>28500.240457</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1979</th>\n",
       "      <td>7824.864356</td>\n",
       "      <td>31502.044296</td>\n",
       "      <td>326.768369</td>\n",
       "      <td>26641.929481</td>\n",
       "      <td>25755.657807</td>\n",
       "      <td>373.832253</td>\n",
       "      <td>17686.527449</td>\n",
       "      <td>24985.791175</td>\n",
       "      <td>36676.579817</td>\n",
       "      <td>22344.368377</td>\n",
       "      <td>29082.593778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1980</th>\n",
       "      <td>8339.269119</td>\n",
       "      <td>31769.783495</td>\n",
       "      <td>347.887413</td>\n",
       "      <td>26962.258783</td>\n",
       "      <td>26064.389314</td>\n",
       "      <td>389.926284</td>\n",
       "      <td>17555.374050</td>\n",
       "      <td>25489.166212</td>\n",
       "      <td>36518.045017</td>\n",
       "      <td>21865.128077</td>\n",
       "      <td>28734.399260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1981</th>\n",
       "      <td>7788.333546</td>\n",
       "      <td>32477.295536</td>\n",
       "      <td>361.224711</td>\n",
       "      <td>27132.702886</td>\n",
       "      <td>26162.454627</td>\n",
       "      <td>403.878043</td>\n",
       "      <td>18128.197415</td>\n",
       "      <td>26358.347887</td>\n",
       "      <td>34979.348502</td>\n",
       "      <td>21688.996341</td>\n",
       "      <td>29191.999488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1982</th>\n",
       "      <td>7653.559562</td>\n",
       "      <td>31060.644997</td>\n",
       "      <td>387.745581</td>\n",
       "      <td>27677.450585</td>\n",
       "      <td>26083.952145</td>\n",
       "      <td>408.324880</td>\n",
       "      <td>18161.403359</td>\n",
       "      <td>27064.101587</td>\n",
       "      <td>26017.961881</td>\n",
       "      <td>22132.998247</td>\n",
       "      <td>28362.494616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1983</th>\n",
       "      <td>7225.538359</td>\n",
       "      <td>31549.801704</td>\n",
       "      <td>423.593499</td>\n",
       "      <td>27876.663651</td>\n",
       "      <td>26563.644936</td>\n",
       "      <td>428.069168</td>\n",
       "      <td>18433.073989</td>\n",
       "      <td>27703.018898</td>\n",
       "      <td>20512.784349</td>\n",
       "      <td>23059.570217</td>\n",
       "      <td>29406.257469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1984</th>\n",
       "      <td>7439.110811</td>\n",
       "      <td>33099.374446</td>\n",
       "      <td>481.364596</td>\n",
       "      <td>28144.023419</td>\n",
       "      <td>27408.099813</td>\n",
       "      <td>434.366505</td>\n",
       "      <td>18363.527461</td>\n",
       "      <td>28756.637828</td>\n",
       "      <td>18427.004739</td>\n",
       "      <td>23547.164676</td>\n",
       "      <td>31268.975645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1985</th>\n",
       "      <td>7860.015578</td>\n",
       "      <td>34345.613960</td>\n",
       "      <td>538.690827</td>\n",
       "      <td>28437.249780</td>\n",
       "      <td>28108.893013</td>\n",
       "      <td>446.999769</td>\n",
       "      <td>18770.866975</td>\n",
       "      <td>30391.524934</td>\n",
       "      <td>15734.737370</td>\n",
       "      <td>24479.425426</td>\n",
       "      <td>32306.833057</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1986</th>\n",
       "      <td>8314.878921</td>\n",
       "      <td>34737.275318</td>\n",
       "      <td>578.184040</td>\n",
       "      <td>28934.193414</td>\n",
       "      <td>28738.682612</td>\n",
       "      <td>458.086675</td>\n",
       "      <td>19257.140369</td>\n",
       "      <td>31062.093246</td>\n",
       "      <td>17509.058826</td>\n",
       "      <td>25185.830402</td>\n",
       "      <td>33133.695444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1987</th>\n",
       "      <td>8445.036995</td>\n",
       "      <td>35689.034221</td>\n",
       "      <td>635.494603</td>\n",
       "      <td>29499.257296</td>\n",
       "      <td>29096.910185</td>\n",
       "      <td>465.982729</td>\n",
       "      <td>20301.478541</td>\n",
       "      <td>32179.290398</td>\n",
       "      <td>15608.752709</td>\n",
       "      <td>26465.995011</td>\n",
       "      <td>33975.654795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1988</th>\n",
       "      <td>8276.542522</td>\n",
       "      <td>36791.759638</td>\n",
       "      <td>695.599054</td>\n",
       "      <td>30706.691072</td>\n",
       "      <td>30057.941478</td>\n",
       "      <td>500.013277</td>\n",
       "      <td>20548.910736</td>\n",
       "      <td>34332.258170</td>\n",
       "      <td>16921.548686</td>\n",
       "      <td>27925.081236</td>\n",
       "      <td>35083.969043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1989</th>\n",
       "      <td>8391.192292</td>\n",
       "      <td>36981.278773</td>\n",
       "      <td>713.689528</td>\n",
       "      <td>31852.750006</td>\n",
       "      <td>30988.589460</td>\n",
       "      <td>518.698716</td>\n",
       "      <td>20319.362638</td>\n",
       "      <td>36028.153127</td>\n",
       "      <td>16194.514299</td>\n",
       "      <td>28568.126539</td>\n",
       "      <td>36033.330203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1990</th>\n",
       "      <td>7986.045873</td>\n",
       "      <td>36489.266354</td>\n",
       "      <td>730.772489</td>\n",
       "      <td>32596.005188</td>\n",
       "      <td>32337.101023</td>\n",
       "      <td>536.162786</td>\n",
       "      <td>21141.296987</td>\n",
       "      <td>37906.163702</td>\n",
       "      <td>18002.738727</td>\n",
       "      <td>28691.292875</td>\n",
       "      <td>36312.414183</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1991</th>\n",
       "      <td>7966.796468</td>\n",
       "      <td>35231.021206</td>\n",
       "      <td>787.867435</td>\n",
       "      <td>32908.557969</td>\n",
       "      <td>33742.219217</td>\n",
       "      <td>530.894738</td>\n",
       "      <td>21444.754977</td>\n",
       "      <td>39044.927269</td>\n",
       "      <td>20040.495704</td>\n",
       "      <td>28291.921960</td>\n",
       "      <td>35803.868421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1992</th>\n",
       "      <td>7796.841831</td>\n",
       "      <td>35108.519038</td>\n",
       "      <td>888.911004</td>\n",
       "      <td>33269.146912</td>\n",
       "      <td>34130.852398</td>\n",
       "      <td>548.895784</td>\n",
       "      <td>22323.835383</td>\n",
       "      <td>39267.120243</td>\n",
       "      <td>20226.854221</td>\n",
       "      <td>28321.034526</td>\n",
       "      <td>36566.173770</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1993</th>\n",
       "      <td>8027.189384</td>\n",
       "      <td>35648.477988</td>\n",
       "      <td>1000.611810</td>\n",
       "      <td>32922.363003</td>\n",
       "      <td>33583.006036</td>\n",
       "      <td>563.749688</td>\n",
       "      <td>22633.175722</td>\n",
       "      <td>39237.326692</td>\n",
       "      <td>19413.782933</td>\n",
       "      <td>28967.155045</td>\n",
       "      <td>37078.049684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1994</th>\n",
       "      <td>8319.209024</td>\n",
       "      <td>36893.981886</td>\n",
       "      <td>1118.499577</td>\n",
       "      <td>33569.330886</td>\n",
       "      <td>34289.124749</td>\n",
       "      <td>589.708788</td>\n",
       "      <td>23692.851616</td>\n",
       "      <td>39441.570730</td>\n",
       "      <td>19041.058398</td>\n",
       "      <td>30014.581399</td>\n",
       "      <td>38104.972468</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1995</th>\n",
       "      <td>8547.941160</td>\n",
       "      <td>37569.468240</td>\n",
       "      <td>1227.556407</td>\n",
       "      <td>34145.705416</td>\n",
       "      <td>34782.568625</td>\n",
       "      <td>622.303683</td>\n",
       "      <td>24592.248121</td>\n",
       "      <td>40368.705109</td>\n",
       "      <td>18648.856867</td>\n",
       "      <td>30674.605547</td>\n",
       "      <td>38677.715088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1996</th>\n",
       "      <td>8598.014959</td>\n",
       "      <td>37765.732451</td>\n",
       "      <td>1335.362680</td>\n",
       "      <td>34497.283908</td>\n",
       "      <td>34965.690886</td>\n",
       "      <td>656.697144</td>\n",
       "      <td>25230.310306</td>\n",
       "      <td>41514.862495</td>\n",
       "      <td>18744.780515</td>\n",
       "      <td>31373.341190</td>\n",
       "      <td>39681.519858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1997</th>\n",
       "      <td>8750.274058</td>\n",
       "      <td>38967.953101</td>\n",
       "      <td>1443.774742</td>\n",
       "      <td>35178.934539</td>\n",
       "      <td>35560.209286</td>\n",
       "      <td>670.610122</td>\n",
       "      <td>25508.753756</td>\n",
       "      <td>41861.911038</td>\n",
       "      <td>18588.304345</td>\n",
       "      <td>32556.316695</td>\n",
       "      <td>40965.846645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1998</th>\n",
       "      <td>8644.420843</td>\n",
       "      <td>40131.701834</td>\n",
       "      <td>1542.064130</td>\n",
       "      <td>36295.935347</td>\n",
       "      <td>36258.674432</td>\n",
       "      <td>699.068855</td>\n",
       "      <td>25970.833783</td>\n",
       "      <td>41277.077476</td>\n",
       "      <td>18763.585928</td>\n",
       "      <td>33480.166013</td>\n",
       "      <td>42292.891201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1999</th>\n",
       "      <td>8554.834439</td>\n",
       "      <td>41856.045592</td>\n",
       "      <td>1645.987996</td>\n",
       "      <td>37339.980975</td>\n",
       "      <td>36955.289607</td>\n",
       "      <td>747.252036</td>\n",
       "      <td>26233.305585</td>\n",
       "      <td>41097.961196</td>\n",
       "      <td>17690.917569</td>\n",
       "      <td>34442.107095</td>\n",
       "      <td>43768.884993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2000</th>\n",
       "      <td>8778.188141</td>\n",
       "      <td>43638.283155</td>\n",
       "      <td>1771.741506</td>\n",
       "      <td>38522.210279</td>\n",
       "      <td>37998.425312</td>\n",
       "      <td>762.313341</td>\n",
       "      <td>27636.340774</td>\n",
       "      <td>42169.697876</td>\n",
       "      <td>18263.230082</td>\n",
       "      <td>35576.766847</td>\n",
       "      <td>45055.817918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001</th>\n",
       "      <td>8776.863719</td>\n",
       "      <td>43964.954594</td>\n",
       "      <td>1905.610780</td>\n",
       "      <td>38990.306133</td>\n",
       "      <td>38577.725629</td>\n",
       "      <td>785.344628</td>\n",
       "      <td>26999.365062</td>\n",
       "      <td>42239.113089</td>\n",
       "      <td>17585.390455</td>\n",
       "      <td>36341.709750</td>\n",
       "      <td>45047.487198</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2002</th>\n",
       "      <td>8924.333528</td>\n",
       "      <td>44883.828396</td>\n",
       "      <td>2065.718579</td>\n",
       "      <td>39140.714764</td>\n",
       "      <td>38512.920041</td>\n",
       "      <td>801.507933</td>\n",
       "      <td>26503.843351</td>\n",
       "      <td>42190.778633</td>\n",
       "      <td>16619.434659</td>\n",
       "      <td>37077.648355</td>\n",
       "      <td>45428.645678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2003</th>\n",
       "      <td>8910.855624</td>\n",
       "      <td>45239.811391</td>\n",
       "      <td>2258.912105</td>\n",
       "      <td>39182.779389</td>\n",
       "      <td>38218.349646</td>\n",
       "      <td>850.293265</td>\n",
       "      <td>26229.399767</td>\n",
       "      <td>42743.993910</td>\n",
       "      <td>17954.949794</td>\n",
       "      <td>38132.841086</td>\n",
       "      <td>46304.036090</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2004</th>\n",
       "      <td>9309.008195</td>\n",
       "      <td>46170.920197</td>\n",
       "      <td>2472.586556</td>\n",
       "      <td>39979.119715</td>\n",
       "      <td>38673.888113</td>\n",
       "      <td>902.905794</td>\n",
       "      <td>26947.444502</td>\n",
       "      <td>43671.680357</td>\n",
       "      <td>18822.730041</td>\n",
       "      <td>38813.021694</td>\n",
       "      <td>47614.279862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2005</th>\n",
       "      <td>9495.104939</td>\n",
       "      <td>47181.562394</td>\n",
       "      <td>2738.205460</td>\n",
       "      <td>40316.812760</td>\n",
       "      <td>38969.321698</td>\n",
       "      <td>971.229761</td>\n",
       "      <td>27570.945566</td>\n",
       "      <td>44393.662740</td>\n",
       "      <td>19309.312067</td>\n",
       "      <td>39740.902921</td>\n",
       "      <td>48755.616061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2006</th>\n",
       "      <td>9761.876402</td>\n",
       "      <td>48035.035792</td>\n",
       "      <td>3069.304781</td>\n",
       "      <td>40987.549076</td>\n",
       "      <td>40456.857380</td>\n",
       "      <td>1044.893940</td>\n",
       "      <td>28499.325927</td>\n",
       "      <td>44995.521508</td>\n",
       "      <td>19304.550232</td>\n",
       "      <td>40418.747305</td>\n",
       "      <td>49575.401014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2007</th>\n",
       "      <td>10245.230813</td>\n",
       "      <td>48552.696431</td>\n",
       "      <td>3487.845766</td>\n",
       "      <td>41696.692239</td>\n",
       "      <td>41831.867088</td>\n",
       "      <td>1130.090071</td>\n",
       "      <td>29614.075151</td>\n",
       "      <td>45687.345751</td>\n",
       "      <td>19136.159234</td>\n",
       "      <td>41050.405926</td>\n",
       "      <td>49979.533843</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2008</th>\n",
       "      <td>10658.225308</td>\n",
       "      <td>48510.567773</td>\n",
       "      <td>3805.025999</td>\n",
       "      <td>41545.294092</td>\n",
       "      <td>42365.097496</td>\n",
       "      <td>1156.932527</td>\n",
       "      <td>29962.005614</td>\n",
       "      <td>45165.887162</td>\n",
       "      <td>19792.720384</td>\n",
       "      <td>40536.134857</td>\n",
       "      <td>49364.644550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2009</th>\n",
       "      <td>10540.108767</td>\n",
       "      <td>46543.792200</td>\n",
       "      <td>4142.038286</td>\n",
       "      <td>40116.379527</td>\n",
       "      <td>40086.104759</td>\n",
       "      <td>1237.339786</td>\n",
       "      <td>29658.522632</td>\n",
       "      <td>42724.534911</td>\n",
       "      <td>18861.109998</td>\n",
       "      <td>38545.915816</td>\n",
       "      <td>47575.608563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2010</th>\n",
       "      <td>11224.154083</td>\n",
       "      <td>47447.476024</td>\n",
       "      <td>4560.512586</td>\n",
       "      <td>40703.345921</td>\n",
       "      <td>41785.556913</td>\n",
       "      <td>1345.770153</td>\n",
       "      <td>30642.940617</td>\n",
       "      <td>44507.676386</td>\n",
       "      <td>19259.587257</td>\n",
       "      <td>38893.018494</td>\n",
       "      <td>48373.878816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011</th>\n",
       "      <td>11559.212271</td>\n",
       "      <td>48456.964574</td>\n",
       "      <td>4971.544929</td>\n",
       "      <td>41349.191571</td>\n",
       "      <td>44125.331412</td>\n",
       "      <td>1416.403391</td>\n",
       "      <td>31482.997157</td>\n",
       "      <td>44538.726191</td>\n",
       "      <td>20575.497951</td>\n",
       "      <td>39150.756019</td>\n",
       "      <td>48783.468587</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2012</th>\n",
       "      <td>11671.182943</td>\n",
       "      <td>48724.245800</td>\n",
       "      <td>5336.060143</td>\n",
       "      <td>41224.729662</td>\n",
       "      <td>44259.259905</td>\n",
       "      <td>1474.967674</td>\n",
       "      <td>31507.478563</td>\n",
       "      <td>45276.874335</td>\n",
       "      <td>21056.347147</td>\n",
       "      <td>39455.412136</td>\n",
       "      <td>49497.585853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2013</th>\n",
       "      <td>11912.146756</td>\n",
       "      <td>49359.422470</td>\n",
       "      <td>5721.693819</td>\n",
       "      <td>41249.449408</td>\n",
       "      <td>44354.736887</td>\n",
       "      <td>1550.142230</td>\n",
       "      <td>32196.358171</td>\n",
       "      <td>46249.209589</td>\n",
       "      <td>21005.012123</td>\n",
       "      <td>39996.501902</td>\n",
       "      <td>49976.628767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2014</th>\n",
       "      <td>11866.388922</td>\n",
       "      <td>50221.841982</td>\n",
       "      <td>6108.238775</td>\n",
       "      <td>41431.038668</td>\n",
       "      <td>45022.565349</td>\n",
       "      <td>1646.781252</td>\n",
       "      <td>32661.294135</td>\n",
       "      <td>46484.155267</td>\n",
       "      <td>21183.464888</td>\n",
       "      <td>40908.745855</td>\n",
       "      <td>50881.106863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2015</th>\n",
       "      <td>11322.146680</td>\n",
       "      <td>50303.836848</td>\n",
       "      <td>6496.624013</td>\n",
       "      <td>41689.707979</td>\n",
       "      <td>45412.556808</td>\n",
       "      <td>1758.043376</td>\n",
       "      <td>32993.314276</td>\n",
       "      <td>47163.494211</td>\n",
       "      <td>21507.955693</td>\n",
       "      <td>41536.919115</td>\n",
       "      <td>51956.583468</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2016</th>\n",
       "      <td>10826.271435</td>\n",
       "      <td>50407.341330</td>\n",
       "      <td>6893.776361</td>\n",
       "      <td>42015.738294</td>\n",
       "      <td>45845.526542</td>\n",
       "      <td>1861.491029</td>\n",
       "      <td>33677.461945</td>\n",
       "      <td>47660.893039</td>\n",
       "      <td>21395.359780</td>\n",
       "      <td>42039.736364</td>\n",
       "      <td>52364.244025</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "country        Brazil        Canada        China        France       Germany  \\\n",
       "year                                                                           \n",
       "1970      4706.126393  24629.215564   228.317703  20090.770923  19624.749759   \n",
       "1971      5108.945626  25262.441350   237.813838  20985.170602  20202.433743   \n",
       "1972      5586.683824  26216.591277   240.881889  21739.884077  20970.626433   \n",
       "1973      6216.130816  27571.292534   253.714373  22903.302398  21903.403507   \n",
       "1974      6617.883590  28080.940720   254.267485  23690.600278  22089.748966   \n",
       "1975      6798.096714  28057.047887   271.599476  23298.340492  21980.087890   \n",
       "1976      7287.519095  29128.014087   263.230622  24174.746998  23167.059261   \n",
       "1977      7443.898645  29783.267924   279.324547  24907.284189  23996.772912   \n",
       "1978      7504.296020  30651.632553   307.766195  25811.802598  24740.236527   \n",
       "1979      7824.864356  31502.044296   326.768369  26641.929481  25755.657807   \n",
       "1980      8339.269119  31769.783495   347.887413  26962.258783  26064.389314   \n",
       "1981      7788.333546  32477.295536   361.224711  27132.702886  26162.454627   \n",
       "1982      7653.559562  31060.644997   387.745581  27677.450585  26083.952145   \n",
       "1983      7225.538359  31549.801704   423.593499  27876.663651  26563.644936   \n",
       "1984      7439.110811  33099.374446   481.364596  28144.023419  27408.099813   \n",
       "1985      7860.015578  34345.613960   538.690827  28437.249780  28108.893013   \n",
       "1986      8314.878921  34737.275318   578.184040  28934.193414  28738.682612   \n",
       "1987      8445.036995  35689.034221   635.494603  29499.257296  29096.910185   \n",
       "1988      8276.542522  36791.759638   695.599054  30706.691072  30057.941478   \n",
       "1989      8391.192292  36981.278773   713.689528  31852.750006  30988.589460   \n",
       "1990      7986.045873  36489.266354   730.772489  32596.005188  32337.101023   \n",
       "1991      7966.796468  35231.021206   787.867435  32908.557969  33742.219217   \n",
       "1992      7796.841831  35108.519038   888.911004  33269.146912  34130.852398   \n",
       "1993      8027.189384  35648.477988  1000.611810  32922.363003  33583.006036   \n",
       "1994      8319.209024  36893.981886  1118.499577  33569.330886  34289.124749   \n",
       "1995      8547.941160  37569.468240  1227.556407  34145.705416  34782.568625   \n",
       "1996      8598.014959  37765.732451  1335.362680  34497.283908  34965.690886   \n",
       "1997      8750.274058  38967.953101  1443.774742  35178.934539  35560.209286   \n",
       "1998      8644.420843  40131.701834  1542.064130  36295.935347  36258.674432   \n",
       "1999      8554.834439  41856.045592  1645.987996  37339.980975  36955.289607   \n",
       "2000      8778.188141  43638.283155  1771.741506  38522.210279  37998.425312   \n",
       "2001      8776.863719  43964.954594  1905.610780  38990.306133  38577.725629   \n",
       "2002      8924.333528  44883.828396  2065.718579  39140.714764  38512.920041   \n",
       "2003      8910.855624  45239.811391  2258.912105  39182.779389  38218.349646   \n",
       "2004      9309.008195  46170.920197  2472.586556  39979.119715  38673.888113   \n",
       "2005      9495.104939  47181.562394  2738.205460  40316.812760  38969.321698   \n",
       "2006      9761.876402  48035.035792  3069.304781  40987.549076  40456.857380   \n",
       "2007     10245.230813  48552.696431  3487.845766  41696.692239  41831.867088   \n",
       "2008     10658.225308  48510.567773  3805.025999  41545.294092  42365.097496   \n",
       "2009     10540.108767  46543.792200  4142.038286  40116.379527  40086.104759   \n",
       "2010     11224.154083  47447.476024  4560.512586  40703.345921  41785.556913   \n",
       "2011     11559.212271  48456.964574  4971.544929  41349.191571  44125.331412   \n",
       "2012     11671.182943  48724.245800  5336.060143  41224.729662  44259.259905   \n",
       "2013     11912.146756  49359.422470  5721.693819  41249.449408  44354.736887   \n",
       "2014     11866.388922  50221.841982  6108.238775  41431.038668  45022.565349   \n",
       "2015     11322.146680  50303.836848  6496.624013  41689.707979  45412.556808   \n",
       "2016     10826.271435  50407.341330  6893.776361  42015.738294  45845.526542   \n",
       "\n",
       "country        India        Israel         Japan  Saudi Arabia  \\\n",
       "year                                                             \n",
       "1970      365.057383  14476.725344  18435.455076  22133.904924   \n",
       "1971      362.767725  14750.563777  19054.841724  25517.184135   \n",
       "1972      352.550056  16127.116380  20370.673766  29931.470715   \n",
       "1973      355.788210  16352.902703  21825.543720  35393.583566   \n",
       "1974      351.708069  16901.495224  21150.496237  39125.445624   \n",
       "1975      375.083373  17055.638892  21458.049820  33860.391587   \n",
       "1976      372.642613  16732.882702  22146.595979  37905.118963   \n",
       "1977      390.636680  16531.778366  22897.185145  38557.214887   \n",
       "1978      403.633544  17085.997506  23887.179963  34657.430841   \n",
       "1979      373.832253  17686.527449  24985.791175  36676.579817   \n",
       "1980      389.926284  17555.374050  25489.166212  36518.045017   \n",
       "1981      403.878043  18128.197415  26358.347887  34979.348502   \n",
       "1982      408.324880  18161.403359  27064.101587  26017.961881   \n",
       "1983      428.069168  18433.073989  27703.018898  20512.784349   \n",
       "1984      434.366505  18363.527461  28756.637828  18427.004739   \n",
       "1985      446.999769  18770.866975  30391.524934  15734.737370   \n",
       "1986      458.086675  19257.140369  31062.093246  17509.058826   \n",
       "1987      465.982729  20301.478541  32179.290398  15608.752709   \n",
       "1988      500.013277  20548.910736  34332.258170  16921.548686   \n",
       "1989      518.698716  20319.362638  36028.153127  16194.514299   \n",
       "1990      536.162786  21141.296987  37906.163702  18002.738727   \n",
       "1991      530.894738  21444.754977  39044.927269  20040.495704   \n",
       "1992      548.895784  22323.835383  39267.120243  20226.854221   \n",
       "1993      563.749688  22633.175722  39237.326692  19413.782933   \n",
       "1994      589.708788  23692.851616  39441.570730  19041.058398   \n",
       "1995      622.303683  24592.248121  40368.705109  18648.856867   \n",
       "1996      656.697144  25230.310306  41514.862495  18744.780515   \n",
       "1997      670.610122  25508.753756  41861.911038  18588.304345   \n",
       "1998      699.068855  25970.833783  41277.077476  18763.585928   \n",
       "1999      747.252036  26233.305585  41097.961196  17690.917569   \n",
       "2000      762.313341  27636.340774  42169.697876  18263.230082   \n",
       "2001      785.344628  26999.365062  42239.113089  17585.390455   \n",
       "2002      801.507933  26503.843351  42190.778633  16619.434659   \n",
       "2003      850.293265  26229.399767  42743.993910  17954.949794   \n",
       "2004      902.905794  26947.444502  43671.680357  18822.730041   \n",
       "2005      971.229761  27570.945566  44393.662740  19309.312067   \n",
       "2006     1044.893940  28499.325927  44995.521508  19304.550232   \n",
       "2007     1130.090071  29614.075151  45687.345751  19136.159234   \n",
       "2008     1156.932527  29962.005614  45165.887162  19792.720384   \n",
       "2009     1237.339786  29658.522632  42724.534911  18861.109998   \n",
       "2010     1345.770153  30642.940617  44507.676386  19259.587257   \n",
       "2011     1416.403391  31482.997157  44538.726191  20575.497951   \n",
       "2012     1474.967674  31507.478563  45276.874335  21056.347147   \n",
       "2013     1550.142230  32196.358171  46249.209589  21005.012123   \n",
       "2014     1646.781252  32661.294135  46484.155267  21183.464888   \n",
       "2015     1758.043376  32993.314276  47163.494211  21507.955693   \n",
       "2016     1861.491029  33677.461945  47660.893039  21395.359780   \n",
       "\n",
       "country  United Kingdom  United States  \n",
       "year                                    \n",
       "1970       17934.191423   23309.620946  \n",
       "1971       18481.731208   23775.276923  \n",
       "1972       19211.556510   24760.145377  \n",
       "1973       20422.489067   25908.912802  \n",
       "1974       19906.842338   25540.501003  \n",
       "1975       19613.863820   25239.919906  \n",
       "1976       20189.605653   26347.809282  \n",
       "1977       20689.657017   27286.251514  \n",
       "1978       21557.849741   28500.240457  \n",
       "1979       22344.368377   29082.593778  \n",
       "1980       21865.128077   28734.399260  \n",
       "1981       21688.996341   29191.999488  \n",
       "1982       22132.998247   28362.494616  \n",
       "1983       23059.570217   29406.257469  \n",
       "1984       23547.164676   31268.975645  \n",
       "1985       24479.425426   32306.833057  \n",
       "1986       25185.830402   33133.695444  \n",
       "1987       26465.995011   33975.654795  \n",
       "1988       27925.081236   35083.969043  \n",
       "1989       28568.126539   36033.330203  \n",
       "1990       28691.292875   36312.414183  \n",
       "1991       28291.921960   35803.868421  \n",
       "1992       28321.034526   36566.173770  \n",
       "1993       28967.155045   37078.049684  \n",
       "1994       30014.581399   38104.972468  \n",
       "1995       30674.605547   38677.715088  \n",
       "1996       31373.341190   39681.519858  \n",
       "1997       32556.316695   40965.846645  \n",
       "1998       33480.166013   42292.891201  \n",
       "1999       34442.107095   43768.884993  \n",
       "2000       35576.766847   45055.817918  \n",
       "2001       36341.709750   45047.487198  \n",
       "2002       37077.648355   45428.645678  \n",
       "2003       38132.841086   46304.036090  \n",
       "2004       38813.021694   47614.279862  \n",
       "2005       39740.902921   48755.616061  \n",
       "2006       40418.747305   49575.401014  \n",
       "2007       41050.405926   49979.533843  \n",
       "2008       40536.134857   49364.644550  \n",
       "2009       38545.915816   47575.608563  \n",
       "2010       38893.018494   48373.878816  \n",
       "2011       39150.756019   48783.468587  \n",
       "2012       39455.412136   49497.585853  \n",
       "2013       39996.501902   49976.628767  \n",
       "2014       40908.745855   50881.106863  \n",
       "2015       41536.919115   51956.583468  \n",
       "2016       42039.736364   52364.244025  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# World Bank country codes of the economies to model.\n",
    "countries = ['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL', 'JP', 'SA', 'GB', 'US']\n",
    "\n",
    "# Download indicator NY.GDP.PCAP.KD for 1970-2016; the result is indexed\n",
    "# by (country, year) with one column per requested indicator.\n",
    "dat = wb.download(indicator='NY.GDP.PCAP.KD', country=countries,\n",
    "                  start=1970, end=2016)\n",
    "\n",
    "# Pivot into a year x country table and convert the year labels to int.\n",
    "df = dat['NY.GDP.PCAP.KD'].unstack(level='country')\n",
    "df.index = df.index.astype(int)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 搭建神经网络\n",
    "\n",
    "单层 LSTM（5 个隐藏单元）加一个线性输出层：每个国家的序列作为一个独立样本（batch 维），网络在每个时间步输出一个预测值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (rnn): LSTM(1, 5)\n",
       "  (fc): Linear(in_features=5, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super(Net, self).__init__()\n",
    "        self.rnn = torch.nn.LSTM(input_size, hidden_size)\n",
    "        self.fc = torch.nn.Linear(hidden_size, 1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = x[:, :, None]\n",
    "        x, _ = self.rnn(x)\n",
    "        x = self.fc(x)\n",
    "        x = x[:, :, 0]\n",
    "        return x\n",
    "\n",
    "net = Net(input_size=1, hidden_size=5)\n",
    "net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练神经网络\n",
    "\n",
    "各国数据除以其 2000 年的值做归一化；用第 t 年的值预测第 t+1 年的值，目标年份在 1971–2000 年的作为训练集，2001–2016 年的作为测试集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集长度 = 30, 测试集长度 = 16\n",
      "第0次迭代: loss (训练集) = 1.4283555746078491, loss (测试集) = 1.9456963539123535\n",
      "第500次迭代: loss (训练集) = 0.051086507737636566, loss (测试集) = 0.010078947991132736\n",
      "第1000次迭代: loss (训练集) = 0.01944635808467865, loss (测试集) = 0.002939914120361209\n",
      "第1500次迭代: loss (训练集) = 0.007834377698600292, loss (测试集) = 0.0010365095222368836\n",
      "第2000次迭代: loss (训练集) = 0.004264211747795343, loss (测试集) = 0.0005575661198236048\n",
      "第2500次迭代: loss (训练集) = 0.002953553106635809, loss (测试集) = 0.0005140299326740205\n",
      "第3000次迭代: loss (训练集) = 0.0023847392294555902, loss (测试集) = 0.0005197693244554102\n",
      "第3500次迭代: loss (训练集) = 0.002063404768705368, loss (测试集) = 0.0005044733406975865\n",
      "第4000次迭代: loss (训练集) = 0.001874488778412342, loss (测试集) = 0.0004906203248538077\n",
      "第4500次迭代: loss (训练集) = 0.0017608855850994587, loss (测试集) = 0.0004924260429106653\n",
      "第5000次迭代: loss (训练集) = 0.001682320493273437, loss (测试集) = 0.0005076289526186883\n",
      "第5500次迭代: loss (训练集) = 0.0016172596951946616, loss (测试集) = 0.0005304039223119617\n",
      "第6000次迭代: loss (训练集) = 0.0015541095053777099, loss (测试集) = 0.0005540259298868477\n",
      "第6500次迭代: loss (训练集) = 0.0014825011603534222, loss (测试集) = 0.0005765268579125404\n",
      "第7000次迭代: loss (训练集) = 0.0013840177562087774, loss (测试集) = 0.0006041080923750997\n",
      "第7500次迭代: loss (训练集) = 0.001274571055546403, loss (测试集) = 0.0006047789938747883\n",
      "第8000次迭代: loss (训练集) = 0.0011925814906135201, loss (测试集) = 0.0005529309855774045\n",
      "第8500次迭代: loss (训练集) = 0.0011211351957172155, loss (测试集) = 0.0005017257644794881\n",
      "第9000次迭代: loss (训练集) = 0.0010540310759097338, loss (测试集) = 0.00048208076623268425\n",
      "第9500次迭代: loss (训练集) = 0.0009954761480912566, loss (测试集) = 0.00048860814422369\n",
      "第10000次迭代: loss (训练集) = 0.0009423759765923023, loss (测试集) = 0.0004873224243056029\n"
     ]
    }
   ],
   "source": [
    "# Normalise each country's series by its value in the split year, so every\n",
    "# country equals 1.0 in that year and all series share a comparable scale.\n",
    "split_year = 2000\n",
    "df_scaled = df / df.loc[split_year]\n",
    "\n",
    "# Split by label (target) year: targets up to and including split_year form\n",
    "# the training set, later targets the test set. Deriving the counts from the\n",
    "# label years keeps them correct whatever the first downloaded year is.\n",
    "years = df.index\n",
    "label_years = years[1:]\n",
    "train_seq_len = int((label_years <= split_year).sum())\n",
    "test_seq_len = int((label_years > split_year).sum())\n",
    "print('训练集长度 = {}, 测试集长度 = {}'.format(\n",
    "        train_seq_len, test_seq_len))\n",
    "\n",
    "# Features are year t, labels are year t + 1 (one-step-ahead prediction).\n",
    "inputs = torch.tensor(df_scaled.iloc[:-1].values, dtype=torch.float32)\n",
    "labels = torch.tensor(df_scaled.iloc[1:].values, dtype=torch.float32)\n",
    "\n",
    "# Train the network: full-batch Adam, loss on the training window only.\n",
    "criterion = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(net.parameters())\n",
    "n_steps = 10000\n",
    "for step in range(n_steps + 1):\n",
    "    # The LSTM is unidirectional, so one forward pass over the whole sequence\n",
    "    # yields training-window predictions that never see test-window inputs.\n",
    "    preds = net(inputs)\n",
    "    train_preds = preds[:train_seq_len]\n",
    "    train_labels = labels[:train_seq_len]\n",
    "    train_loss = criterion(train_preds, train_labels)\n",
    "\n",
    "    test_preds = preds[train_seq_len:]\n",
    "    test_labels = labels[train_seq_len:]\n",
    "    test_loss = criterion(test_preds, test_labels)\n",
    "\n",
    "    if step % 500 == 0:\n",
    "        print('第{}次迭代: loss (训练集) = {}, loss (测试集) = {}'.format(\n",
    "                step, train_loss.item(), test_loss.item()))\n",
    "\n",
    "    # Evaluate once more after the last update, but never update past n_steps.\n",
    "    if step < n_steps:\n",
    "        optimizer.zero_grad()\n",
    "        train_loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测\n",
    "\n",
    "用训练好的网络对 2001–2016 年做一步预测（每年以上一年的真实值为输入），再乘回 2000 年的值还原到原始尺度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>country</th>\n",
       "      <th>Brazil</th>\n",
       "      <th>Canada</th>\n",
       "      <th>China</th>\n",
       "      <th>France</th>\n",
       "      <th>Germany</th>\n",
       "      <th>India</th>\n",
       "      <th>Israel</th>\n",
       "      <th>Japan</th>\n",
       "      <th>Saudi Arabia</th>\n",
       "      <th>United Kingdom</th>\n",
       "      <th>United States</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>year</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2001</th>\n",
       "      <td>8728.474609</td>\n",
       "      <td>44115.812500</td>\n",
       "      <td>1803.890381</td>\n",
       "      <td>38758.660156</td>\n",
       "      <td>38115.324219</td>\n",
       "      <td>772.356262</td>\n",
       "      <td>27742.117188</td>\n",
       "      <td>41971.281250</td>\n",
       "      <td>17978.492188</td>\n",
       "      <td>35744.425781</td>\n",
       "      <td>45331.175781</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2002</th>\n",
       "      <td>8791.182617</td>\n",
       "      <td>44121.328125</td>\n",
       "      <td>1922.057007</td>\n",
       "      <td>39006.152344</td>\n",
       "      <td>38580.707031</td>\n",
       "      <td>783.310364</td>\n",
       "      <td>27209.820312</td>\n",
       "      <td>42332.281250</td>\n",
       "      <td>17640.318359</td>\n",
       "      <td>36342.664062</td>\n",
       "      <td>45034.355469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2003</th>\n",
       "      <td>8888.078125</td>\n",
       "      <td>44490.312500</td>\n",
       "      <td>2051.746582</td>\n",
       "      <td>38855.230469</td>\n",
       "      <td>38287.113281</td>\n",
       "      <td>793.614380</td>\n",
       "      <td>26333.099609</td>\n",
       "      <td>42042.031250</td>\n",
       "      <td>16624.966797</td>\n",
       "      <td>36820.015625</td>\n",
       "      <td>45028.613281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2004</th>\n",
       "      <td>8859.866211</td>\n",
       "      <td>44770.628906</td>\n",
       "      <td>2198.029053</td>\n",
       "      <td>38760.089844</td>\n",
       "      <td>37806.500000</td>\n",
       "      <td>837.694824</td>\n",
       "      <td>26041.751953</td>\n",
       "      <td>42403.218750</td>\n",
       "      <td>17823.111328</td>\n",
       "      <td>37667.347656</td>\n",
       "      <td>45898.242188</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2005</th>\n",
       "      <td>9210.673828</td>\n",
       "      <td>45581.945312</td>\n",
       "      <td>2339.192139</td>\n",
       "      <td>39543.230469</td>\n",
       "      <td>38244.601562</td>\n",
       "      <td>887.680359</td>\n",
       "      <td>26928.919922</td>\n",
       "      <td>43388.410156</td>\n",
       "      <td>19140.501953</td>\n",
       "      <td>38211.289062</td>\n",
       "      <td>47274.554688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2006</th>\n",
       "      <td>9422.847656</td>\n",
       "      <td>46537.699219</td>\n",
       "      <td>2496.956055</td>\n",
       "      <td>39985.117188</td>\n",
       "      <td>38723.804688</td>\n",
       "      <td>939.492920</td>\n",
       "      <td>27781.509766</td>\n",
       "      <td>44090.531250</td>\n",
       "      <td>19460.792969</td>\n",
       "      <td>38858.921875</td>\n",
       "      <td>48266.789062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2007</th>\n",
       "      <td>9592.635742</td>\n",
       "      <td>47181.664062</td>\n",
       "      <td>2679.909912</td>\n",
       "      <td>40467.410156</td>\n",
       "      <td>40128.886719</td>\n",
       "      <td>985.826965</td>\n",
       "      <td>28578.460938</td>\n",
       "      <td>44437.285156</td>\n",
       "      <td>19059.839844</td>\n",
       "      <td>39318.265625</td>\n",
       "      <td>48683.855469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2008</th>\n",
       "      <td>9980.925781</td>\n",
       "      <td>47384.859375</td>\n",
       "      <td>2889.713379</td>\n",
       "      <td>41026.304688</td>\n",
       "      <td>41431.019531</td>\n",
       "      <td>1033.875000</td>\n",
       "      <td>29471.312500</td>\n",
       "      <td>44868.574219</td>\n",
       "      <td>18686.330078</td>\n",
       "      <td>39690.921875</td>\n",
       "      <td>48707.433594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2009</th>\n",
       "      <td>10308.732422</td>\n",
       "      <td>47094.582031</td>\n",
       "      <td>2964.640137</td>\n",
       "      <td>40746.019531</td>\n",
       "      <td>41567.589844</td>\n",
       "      <td>1020.524536</td>\n",
       "      <td>29582.595703</td>\n",
       "      <td>44243.945312</td>\n",
       "      <td>19352.714844</td>\n",
       "      <td>39017.917969</td>\n",
       "      <td>47921.152344</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2010</th>\n",
       "      <td>10042.784180</td>\n",
       "      <td>45130.597656</td>\n",
       "      <td>3082.189941</td>\n",
       "      <td>39192.152344</td>\n",
       "      <td>38951.582031</td>\n",
       "      <td>1063.351318</td>\n",
       "      <td>28979.738281</td>\n",
       "      <td>41773.062500</td>\n",
       "      <td>18638.841797</td>\n",
       "      <td>37010.519531</td>\n",
       "      <td>46168.046875</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011</th>\n",
       "      <td>10545.157227</td>\n",
       "      <td>45927.894531</td>\n",
       "      <td>3271.759277</td>\n",
       "      <td>39668.851562</td>\n",
       "      <td>40243.339844</td>\n",
       "      <td>1142.086304</td>\n",
       "      <td>29782.060547</td>\n",
       "      <td>43372.574219</td>\n",
       "      <td>18805.349609</td>\n",
       "      <td>37405.941406</td>\n",
       "      <td>47021.449219</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2012</th>\n",
       "      <td>10890.395508</td>\n",
       "      <td>47562.371094</td>\n",
       "      <td>3421.420898</td>\n",
       "      <td>40790.527344</td>\n",
       "      <td>43369.804688</td>\n",
       "      <td>1157.708618</td>\n",
       "      <td>30800.066406</td>\n",
       "      <td>44300.109375</td>\n",
       "      <td>20275.322266</td>\n",
       "      <td>38327.886719</td>\n",
       "      <td>48051.511719</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2013</th>\n",
       "      <td>10795.853516</td>\n",
       "      <td>47796.234375</td>\n",
       "      <td>3534.430664</td>\n",
       "      <td>40710.234375</td>\n",
       "      <td>43351.359375</td>\n",
       "      <td>1157.848145</td>\n",
       "      <td>30651.146484</td>\n",
       "      <td>44771.492188</td>\n",
       "      <td>20806.056641</td>\n",
       "      <td>38692.660156</td>\n",
       "      <td>48676.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2014</th>\n",
       "      <td>10842.416016</td>\n",
       "      <td>47898.992188</td>\n",
       "      <td>3644.565918</td>\n",
       "      <td>40405.414062</td>\n",
       "      <td>42414.113281</td>\n",
       "      <td>1193.274658</td>\n",
       "      <td>30945.763672</td>\n",
       "      <td>45415.226562</td>\n",
       "      <td>20308.781250</td>\n",
       "      <td>38956.703125</td>\n",
       "      <td>48838.925781</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2015</th>\n",
       "      <td>10769.652344</td>\n",
       "      <td>48525.160156</td>\n",
       "      <td>3731.458252</td>\n",
       "      <td>40511.148438</td>\n",
       "      <td>42770.261719</td>\n",
       "      <td>1246.998779</td>\n",
       "      <td>31350.187500</td>\n",
       "      <td>45456.238281</td>\n",
       "      <td>20191.625000</td>\n",
       "      <td>39657.914062</td>\n",
       "      <td>49442.578125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2016</th>\n",
       "      <td>10261.981445</td>\n",
       "      <td>48531.710938</td>\n",
       "      <td>3799.287598</td>\n",
       "      <td>40840.281250</td>\n",
       "      <td>43376.960938</td>\n",
       "      <td>1299.196533</td>\n",
       "      <td>31561.748047</td>\n",
       "      <td>45860.765625</td>\n",
       "      <td>20580.929688</td>\n",
       "      <td>40138.867188</td>\n",
       "      <td>50372.425781</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "country        Brazil        Canada        China        France       Germany  \\\n",
       "year                                                                           \n",
       "2001      8728.474609  44115.812500  1803.890381  38758.660156  38115.324219   \n",
       "2002      8791.182617  44121.328125  1922.057007  39006.152344  38580.707031   \n",
       "2003      8888.078125  44490.312500  2051.746582  38855.230469  38287.113281   \n",
       "2004      8859.866211  44770.628906  2198.029053  38760.089844  37806.500000   \n",
       "2005      9210.673828  45581.945312  2339.192139  39543.230469  38244.601562   \n",
       "2006      9422.847656  46537.699219  2496.956055  39985.117188  38723.804688   \n",
       "2007      9592.635742  47181.664062  2679.909912  40467.410156  40128.886719   \n",
       "2008      9980.925781  47384.859375  2889.713379  41026.304688  41431.019531   \n",
       "2009     10308.732422  47094.582031  2964.640137  40746.019531  41567.589844   \n",
       "2010     10042.784180  45130.597656  3082.189941  39192.152344  38951.582031   \n",
       "2011     10545.157227  45927.894531  3271.759277  39668.851562  40243.339844   \n",
       "2012     10890.395508  47562.371094  3421.420898  40790.527344  43369.804688   \n",
       "2013     10795.853516  47796.234375  3534.430664  40710.234375  43351.359375   \n",
       "2014     10842.416016  47898.992188  3644.565918  40405.414062  42414.113281   \n",
       "2015     10769.652344  48525.160156  3731.458252  40511.148438  42770.261719   \n",
       "2016     10261.981445  48531.710938  3799.287598  40840.281250  43376.960938   \n",
       "\n",
       "country        India        Israel         Japan  Saudi Arabia  \\\n",
       "year                                                             \n",
       "2001      772.356262  27742.117188  41971.281250  17978.492188   \n",
       "2002      783.310364  27209.820312  42332.281250  17640.318359   \n",
       "2003      793.614380  26333.099609  42042.031250  16624.966797   \n",
       "2004      837.694824  26041.751953  42403.218750  17823.111328   \n",
       "2005      887.680359  26928.919922  43388.410156  19140.501953   \n",
       "2006      939.492920  27781.509766  44090.531250  19460.792969   \n",
       "2007      985.826965  28578.460938  44437.285156  19059.839844   \n",
       "2008     1033.875000  29471.312500  44868.574219  18686.330078   \n",
       "2009     1020.524536  29582.595703  44243.945312  19352.714844   \n",
       "2010     1063.351318  28979.738281  41773.062500  18638.841797   \n",
       "2011     1142.086304  29782.060547  43372.574219  18805.349609   \n",
       "2012     1157.708618  30800.066406  44300.109375  20275.322266   \n",
       "2013     1157.848145  30651.146484  44771.492188  20806.056641   \n",
       "2014     1193.274658  30945.763672  45415.226562  20308.781250   \n",
       "2015     1246.998779  31350.187500  45456.238281  20191.625000   \n",
       "2016     1299.196533  31561.748047  45860.765625  20580.929688   \n",
       "\n",
       "country  United Kingdom  United States  \n",
       "year                                    \n",
       "2001       35744.425781   45331.175781  \n",
       "2002       36342.664062   45034.355469  \n",
       "2003       36820.015625   45028.613281  \n",
       "2004       37667.347656   45898.242188  \n",
       "2005       38211.289062   47274.554688  \n",
       "2006       38858.921875   48266.789062  \n",
       "2007       39318.265625   48683.855469  \n",
       "2008       39690.921875   48707.433594  \n",
       "2009       39017.917969   47921.152344  \n",
       "2010       37010.519531   46168.046875  \n",
       "2011       37405.941406   47021.449219  \n",
       "2012       38327.886719   48051.511719  \n",
       "2013       38692.660156   48676.000000  \n",
       "2014       38956.703125   48838.925781  \n",
       "2015       39657.914062   49442.578125  \n",
       "2016       40138.867188   50372.425781  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: disable autograd so no computation graph is built.\n",
    "with torch.no_grad():\n",
    "    preds = net(inputs)\n",
    "\n",
    "# Each row is the prediction for the label year years[1:]; undo the\n",
    "# normalisation by multiplying back each country's year-2000 value.\n",
    "df_pred_scaled = pd.DataFrame(preds.numpy(),\n",
    "        index=years[1:], columns=df.columns)\n",
    "df_pred = df_pred_scaled * df.loc[2000]\n",
    "\n",
    "# Show only the held-out (test) years.\n",
    "df_pred.loc[2001:]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
